function [q_hat_partial_dt] = time_derivative_q_hat(model_sol, q_hat, lambdaB)
%TIME_DERIVATIVE_Q_HAT Time derivative of q_hat on the (w, lambda) grid.
%   q_hat_partial_dt = TIME_DERIVATIVE_Q_HAT(model_sol, q_hat, lambdaB)
%   evaluates the current guess q_hat on the solution grid stored in
%   model_sol and returns, elementwise,
%
%       lambda.*(q(w.*(1-kappaw), lambda+kappa_lambda) - q)./(1-kappaw)
%     + lambdaB.*(q - 1)
%     + xK.*(sigmaK + sigmap).*q_w.*w.*sigmaw
%     - ( q_w.*w.*muw + q_lambda.*mu_lambda
%         + 0.5*q_ww.*w.^2.*sigmaw.^2 - rd.*q )
%
%   where q = q_hat(w, lambda) on the grid, and q_w, q_lambda, q_ww are
%   obtained from derivative2D (assumed to return the partial derivatives
%   along w and lambda, in that order -- confirm against derivative2D).
%
%   Inputs:
%     model_sol - struct with fields par, w_matrix, lambda_matrix,
%                 muw_matrix, xK_matrix, sigmaw_matrix, sigmap_matrix,
%                 rd_matrix and kappaw_matrix. par must provide w_grid,
%                 lambda_grid, sigmaK, mu_lambda_fun and kappa_lambda_fun.
%     q_hat     - callable evaluated as q_hat(W, LAMBDA) on grid matrices
%                 (presumably an interpolant of the previous iterate).
%     lambdaB   - multiplies (q - 1) elementwise; scalar or array that
%                 broadcasts with the grid matrices (NOTE(review): confirm
%                 expected shape with the caller).
%
%   Output:
%     q_hat_partial_dt - matrix of time derivatives on the grid.
%
%   Note: the division by (1 - kappaw) is unguarded; entries where
%   kappaw == 1 will yield Inf/NaN.

p   = model_sol.par;
W   = model_sol.w_matrix;
LAM = model_sol.lambda_matrix;

% Coefficients evaluated on the grid (the two par functions are called in
% the same order as before, in case they have side effects).
mu_w    = model_sol.muw_matrix;
mu_lam  = p.mu_lambda_fun(LAM);
kap_lam = p.kappa_lambda_fun(LAM);
x_K     = model_sol.xK_matrix;
sig_w   = model_sol.sigmaw_matrix;
sig_p   = model_sol.sigmap_matrix;
r_d     = model_sol.rd_matrix;
kap_w   = model_sol.kappaw_matrix;

% q_hat at the shifted state (w scaled by 1-kappaw, lambda raised by
% kappa_lambda) and at the grid itself.
q_shift = q_hat(W .* (1 - kap_w), LAM + kap_lam);
q       = q_hat(W, LAM);

% First derivatives in w and lambda, then the second derivative in w.
[dq_dw, dq_dlam] = derivative2D(p.w_grid, p.lambda_grid, q);
[d2q_dw2, ~]     = derivative2D(p.w_grid, p.lambda_grid, dq_dw);

% Individual contributions; operand order inside each term is kept as in
% the original expression so floating-point results are unchanged.
shift_term = LAM .* (q_shift - q) ./ (1 - kap_w);
lamB_term  = lambdaB .* (q - 1);
vol_term   = x_K .* (p.sigmaK + sig_p) .* dq_dw .* W .* sig_w;
gen_term   = dq_dw .* W .* mu_w + dq_dlam .* mu_lam ...
    + 0.5 * d2q_dw2 .* W.^2 .* sig_w.^2 - r_d .* q;

q_hat_partial_dt = shift_term + lamB_term + vol_term - gen_term;


% q_hat_partial_dt =  lambdaB.*( q_hat_matrix-1 ) + xK.*(par.sigmaK+sigmap) .* q_partial_w .* w_matrix .* sigmaw ...
%     -  (   q_partial_w.*w_matrix.*muw +  q_partial_lambda.*mu_lambda + 0.5* q_2D_w .* w_matrix.^2 .* sigmaw.^2 - rd.*q_hat_matrix );
% 



